{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "import tensorflow as tf\n",
    "import itertools\n",
    "import collections\n",
    "import re\n",
    "import random\n",
    "import sentencepiece as spm\n",
    "from tqdm import tqdm\n",
    "import xlnet_utils as squad_utils\n",
    "import xlnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.v9.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open('/home/husein/xlnet/xlnet-squad-train.pkl', 'rb') as fopen:\n",
    "    train_features, train_examples = pickle.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_length = 512\n",
    "doc_stride = 128\n",
    "max_query_length = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 5\n",
    "batch_size = 12\n",
    "warmup_proportion = 0.1\n",
    "n_best_size = 20\n",
    "num_train_steps = int(len(train_features) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "learning_rate = 2e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/alxlnet/xlnet.py:70: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "kwargs = dict(\n",
    "      is_training=True,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.1,\n",
    "      dropatt=0.1,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clamp_len=-1)\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(\n",
    "    json_path = 'alxlnet-base-2020-04-10/config.json'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_parameters = dict(\n",
    "      decay_method = 'poly',\n",
    "      train_steps = num_train_steps,\n",
    "      learning_rate = learning_rate,\n",
    "      warmup_steps = num_warmup_steps,\n",
    "      min_lr_ratio = 0.0,\n",
    "      weight_decay = 0.00,\n",
    "      adam_epsilon = 1e-8,\n",
    "      num_core_per_host = 1,\n",
    "      lr_layer_decay_rate = 1,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clip = 1.0,\n",
    "      clamp_len=-1,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon, \n",
    "                num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,\n",
    "                min_lr_ratio, clip, **kwargs):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "        \n",
    "training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.contrib import layers as contrib_layers\n",
    "\n",
    "class Model:\n",
    "    def __init__(self, is_training = True):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        self.start_positions = tf.placeholder(tf.int32, [None])\n",
    "        self.end_positions = tf.placeholder(tf.int32, [None])\n",
    "        self.p_mask = tf.placeholder(tf.float32, [None, None])\n",
    "        self.is_impossible = tf.placeholder(tf.int32, [None])\n",
    "        self.cls_index = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        xlnet_model = xlnet.XLNetModel(\n",
    "            xlnet_config=xlnet_config,\n",
    "            run_config=xlnet_parameters,\n",
    "            input_ids=tf.transpose(self.X, [1, 0]),\n",
    "            seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "            input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "        \n",
    "        output = xlnet_model.get_sequence_output()\n",
    "        self.output = output\n",
    "        self.model = xlnet_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/alxlnet/xlnet.py:253: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/xlnet.py:253: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/custom_modeling.py:697: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/custom_modeling.py:704: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/custom_modeling.py:809: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/custom_modeling.py:109: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n"
     ]
    }
   ],
   "source": [
    "is_training = True\n",
    "\n",
    "tf.reset_default_graph()\n",
    "model = Model(is_training = is_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_n_top = 5\n",
    "end_n_top = 5\n",
    "seq_len = tf.shape(model.X)[1]\n",
    "initializer = model.model.get_initializer()\n",
    "return_dict = {}\n",
    "p_mask = model.p_mask\n",
    "output = model.output\n",
    "cls_index = model.cls_index\n",
    "\n",
    "with tf.variable_scope('start_logits'):\n",
    "    start_logits = tf.layers.dense(\n",
    "        output, 1, kernel_initializer = initializer\n",
    "    )\n",
    "    start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])\n",
    "    start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask\n",
    "    start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)\n",
    "    \n",
    "with tf.variable_scope('end_logits'):\n",
    "    if is_training:\n",
    "        # during training, compute the end logits based on the\n",
    "        # ground truth of the start position\n",
    "\n",
    "        start_positions = tf.reshape(model.start_positions, [-1])\n",
    "        start_index = tf.one_hot(\n",
    "            start_positions, depth = seq_len, axis = -1, dtype = tf.float32\n",
    "        )\n",
    "        start_features = tf.einsum('lbh,bl->bh', output, start_index)\n",
    "        start_features = tf.tile(start_features[None], [seq_len, 1, 1])\n",
    "        end_logits = tf.layers.dense(\n",
    "            tf.concat([output, start_features], axis = -1),\n",
    "            xlnet_config.d_model,\n",
    "            kernel_initializer = initializer,\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = tf.contrib.layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = initializer,\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])\n",
    "        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "    else:\n",
    "        # during inference, compute the end logits based on beam search\n",
    "\n",
    "        start_top_log_probs, start_top_index = tf.nn.top_k(\n",
    "            start_log_probs, k = start_n_top\n",
    "        )\n",
    "        start_index = tf.one_hot(\n",
    "            start_top_index, depth = seq_len, axis = -1, dtype = tf.float32\n",
    "        )\n",
    "        start_features = tf.einsum('lbh,bkl->bkh', output, start_index)\n",
    "        end_input = tf.tile(\n",
    "            output[:, :, None], [1, 1, start_n_top, 1]\n",
    "        )\n",
    "        start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1])\n",
    "        end_input = tf.concat([end_input, start_features], axis = -1)\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_input,\n",
    "            xlnet_config.d_model,\n",
    "            kernel_initializer = initializer,\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = tf.contrib.layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = initializer,\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        end_logits = tf.reshape(\n",
    "            end_logits, [seq_len, -1, start_n_top]\n",
    "        )\n",
    "        end_logits = tf.transpose(end_logits, [1, 2, 0])\n",
    "        end_logits_masked = (\n",
    "            end_logits * (1 - p_mask[:, None]) - 1e30 * p_mask[:, None]\n",
    "        )\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "        end_top_log_probs, end_top_index = tf.nn.top_k(\n",
    "            end_log_probs, k = end_n_top\n",
    "        )\n",
    "        end_top_log_probs = tf.reshape(\n",
    "            end_top_log_probs, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "        end_top_index = tf.reshape(\n",
    "            end_top_index, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "\n",
    "if is_training:\n",
    "    return_dict['start_log_probs'] = start_log_probs\n",
    "    return_dict['end_log_probs'] = end_log_probs\n",
    "else:\n",
    "    return_dict['start_top_log_probs'] = start_top_log_probs\n",
    "    return_dict['start_top_index'] = start_top_index\n",
    "    return_dict['end_top_log_probs'] = end_top_log_probs\n",
    "    return_dict['end_top_index'] = end_top_index\n",
    "\n",
    "# an additional layer to predict answerability\n",
    "with tf.variable_scope('answer_class'):\n",
    "    # get the representation of CLS\n",
    "    cls_index = tf.one_hot(\n",
    "        cls_index, seq_len, axis = -1, dtype = tf.float32\n",
    "    )\n",
    "    cls_feature = tf.einsum('lbh,bl->bh', output, cls_index)\n",
    "\n",
    "    # get the representation of START\n",
    "    start_p = tf.nn.softmax(\n",
    "        start_logits_masked, axis = -1, name = 'softmax_start'\n",
    "    )\n",
    "    start_feature = tf.einsum('lbh,bl->bh', output, start_p)\n",
    "\n",
    "    # note(zhiliny): no dependency on end_feature so that we can obtain\n",
    "    # one single `cls_logits` for each sample\n",
    "    ans_feature = tf.concat([start_feature, cls_feature], -1)\n",
    "    ans_feature = tf.layers.dense(\n",
    "        ans_feature,\n",
    "        xlnet_config.d_model,\n",
    "        activation = tf.tanh,\n",
    "        kernel_initializer = initializer,\n",
    "        name = 'dense_0',\n",
    "    )\n",
    "    ans_feature = tf.layers.dropout(\n",
    "        ans_feature, 0.1, training = is_training\n",
    "    )\n",
    "    cls_logits = tf.layers.dense(\n",
    "        ans_feature,\n",
    "        1,\n",
    "        kernel_initializer = initializer,\n",
    "        name = 'dense_1',\n",
    "        use_bias = False,\n",
    "    )\n",
    "    cls_logits = tf.squeeze(cls_logits, -1)\n",
    "\n",
    "    return_dict['cls_logits'] = cls_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "seq_length = tf.shape(model.X)[1]\n",
    "\n",
    "def compute_loss(log_probs, positions):\n",
    "    one_hot_positions = tf.one_hot(\n",
    "        positions, depth = seq_length, dtype = tf.float32\n",
    "    )\n",
    "\n",
    "    loss = -tf.reduce_sum(one_hot_positions * log_probs, axis = -1)\n",
    "    loss = tf.reduce_mean(loss)\n",
    "    return loss\n",
    "\n",
    "start_loss = compute_loss(\n",
    "    return_dict['start_log_probs'], model.start_positions\n",
    ")\n",
    "end_loss = compute_loss(\n",
    "    return_dict['end_log_probs'], model.end_positions\n",
    ")\n",
    "\n",
    "total_loss = (start_loss + end_loss) * 0.5\n",
    "\n",
    "cls_logits = return_dict['cls_logits']\n",
    "is_impossible = tf.reshape(model.is_impossible, [-1])\n",
    "regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "    labels = tf.cast(is_impossible, dtype = tf.float32),\n",
    "    logits = cls_logits,\n",
    ")\n",
    "regression_loss = tf.reduce_mean(regression_loss)\n",
    "\n",
    "# note(zhiliny): by default multiply the loss by 0.5 so that the scale is\n",
    "# comparable to start_loss and end_loss\n",
    "total_loss += regression_loss * 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/alxlnet/model_utils.py:334: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/model_utils.py:105: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/model_utils.py:119: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/alxlnet/model_utils.py:150: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import model_utils\n",
    "\n",
    "optimizer, _, _ = model_utils.get_train_op(training_parameters, total_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Compute the union of the current variables and checkpoint variables.\"\"\"\n",
    "    assignment_map = {}\n",
    "    initialized_variable_names = {}\n",
    "\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        name = var.name\n",
    "        m = re.match('^(.*):\\\\d+$', name)\n",
    "        if m is not None:\n",
    "            name = m.group(1)\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    init_vars = tf.train.list_variables(init_checkpoint)\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    for x in init_vars:\n",
    "        (name, var) = (x[0], x[1])\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "tvars = tf.trainable_variables()\n",
    "checkpoint = 'alxlnet-base-2020-04-10/model.ckpt-300000'\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, \n",
    "                                                                                checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-2020-04-10/model.ckpt-300000\n"
     ]
    }
   ],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 10898/10898 [3:09:26<00:00,  1.04s/it, cost=0.684, end_loss=0.00015, regression_loss=0.714, start_loss=0.654]    "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n",
      "1.0763367\n",
      "1.291288\n",
      "0.45299488\n",
      "0.40839055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for e in range(1):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_features), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    costs, start_losses, end_losses, regression_losses = [], [], [], []\n",
    "    for i in pbar:\n",
    "        batch = train_features[i: i + batch_size]\n",
    "        batch_ids = [b.input_ids for b in batch]\n",
    "        batch_masks = [b.input_mask for b in batch]\n",
    "        batch_segment = [b.segment_ids for b in batch]\n",
    "        batch_start = [b.start_position for b in batch]\n",
    "        batch_end = [b.end_position for b in batch]\n",
    "        is_impossible = [b.is_impossible for b in batch]\n",
    "        p_mask = [b.p_mask for b in batch]\n",
    "        cls_index = [b.cls_index for b in batch]\n",
    "        cost, start_loss_, end_loss_, regression_loss_, _ = sess.run(\n",
    "            [total_loss, start_loss, end_loss, regression_loss, optimizer],\n",
    "            feed_dict = {\n",
    "                model.start_positions: batch_start,\n",
    "                model.end_positions: batch_end,\n",
    "                model.X: batch_ids,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks,\n",
    "                model.is_impossible: is_impossible,\n",
    "                model.p_mask: p_mask,\n",
    "                model.cls_index: cls_index\n",
    "            },\n",
    "        )\n",
    "        pbar.set_postfix(cost = cost, start_loss = start_loss_,\n",
    "                        end_loss = end_loss_, regression_loss = regression_loss_)\n",
    "        costs.append(cost)\n",
    "        start_losses.append(start_loss_)\n",
    "        end_losses.append(end_loss_)\n",
    "        regression_losses.append(regression_loss_)\n",
    "        \n",
    "    print(f'epoch: {e}')\n",
    "    print(np.mean(costs))\n",
    "    print(np.mean(start_losses))\n",
    "    print(np.mean(end_losses))\n",
    "    print(np.mean(regression_losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'alxlnet-base-squad/model.ckpt'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'alxlnet-base-squad/model.ckpt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
